﻿namespace Tensorflow.Keras.Losses;

/// <summary>
/// Log-cosh loss: the mean over <c>axis</c> of <c>log(cosh(y_pred - y_true))</c>.
/// </summary>
public class LogCosh : LossFunctionWrapper
{
    /// <summary>Creates a log-cosh loss.</summary>
    /// <param name="reduction">Reduction type, forwarded unchanged to <see cref="LossFunctionWrapper"/>.</param>
    /// <param name="name">Loss name; <c>"log_cosh"</c> is used when null.</param>
    public LogCosh(
        string reduction = null,
        string name = null) :
        base(reduction: reduction, name: name ?? "log_cosh")
    { }

    /// <summary>
    /// Computes <c>mean(log(cosh(y_pred - y_true)), axis)</c>.
    /// </summary>
    /// <param name="y_true">Ground-truth values; cast to the dtype of <paramref name="y_pred"/>.</param>
    /// <param name="y_pred">Predicted values.</param>
    /// <param name="from_logits">Ignored by this loss.</param>
    /// <param name="axis">Axis the per-element losses are averaged over (default -1, the last axis).</param>
    /// <returns>The per-sample log-cosh loss, with <paramref name="axis"/> reduced away.</returns>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
        Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
        Tensor x = y_pred_dispatch - y_true_cast;

        // log(cosh(x)) = log((e^x + e^-x) / 2) = x + log(1 + e^(-2x)) - log(2)
        //              = x + softplus(-2x) - log(2).
        // This form never evaluates cosh(x) directly, which overflows for large |x|.
        // log(2) is a plain constant; a tf.Variable here would allocate a new variable on every call.
        Tensor log2 = math_ops.cast(math_ops.log(tf.constant(2.0)), x.dtype);
        Tensor log_cosh = x + gen_nn_ops.softplus(-2.0 * x) - log2;

        // Honor the caller-supplied axis instead of always reducing the last one.
        return gen_math_ops.mean(log_cosh, ops.convert_to_tensor(axis));
    }
}